import torch

# Build a small 2x3 integer tensor, then zero every element below 5 in place.
A = torch.tensor([[1, 2, 3], [4, 5, 6]])
# NOTE(review): `msk` is not used anywhere in this script; it is kept in case
# other code imports it. Confirm whether it was meant to be the fill mask.
msk = torch.tensor([[True, True, False], [True, False, True]])
print(A)
# `A[A < 5]` used as an expression is boolean-mask (advanced) indexing, which
# returns a new tensor rather than a view. Calling `.fill_(0)` on it would only
# zero that temporary copy and leave A unchanged. Index assignment goes
# through `__setitem__` and writes into A itself.
A[A < 5] = 0
print(A)
